{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7c7d0bb-8d45-4139-a584-02c7196db92b",
   "metadata": {},
   "source": [
    "# Find the best few-shot examples using simulation-based inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "831a76f5-c569-4174-adab-fb0245877367",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import requests\n",
    "import re\n",
    "\n",
    "import openai\n",
    "\n",
    "import outlines\n",
    "\n",
    "random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ec604edc-c8b6-4088-bf17-b77ae57d05a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: OPENAI_API_KEY=# your key here\n"
     ]
    }
   ],
   "source": [
    "%env OPENAI_API_KEY = # your key here"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aabb4db6-fd94-4c42-ab7f-97c3de45b2cc",
   "metadata": {},
   "source": [
    "In this example we will use GPT-4o mini to solve problems from the GSM8K dataset. Strong results on this dataset have been reported using few-shot prompting with 5 examples. However, it is not clear how one should select these examples. Here, we will use **simulation-based inference** to infer which examples get the most out of the model's ability to solve these problems.\n",
    "\n",
    "Let's start with downloading the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "367f5f89-8e5d-4381-b9eb-78c60bc50f86",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = requests.get(\n",
    "    \"https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl\"\n",
    ")\n",
    "lines = result.iter_lines()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef0f7aa9-d528-41e9-8a9d-4497f01f0692",
   "metadata": {},
   "source": [
    "We now split the train set into two sets:\n",
    "- 10 problems from which we sample 5 examples at random for every inference;\n",
    "- 500 problems which we use to perform inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0667c4a8-cebe-4796-bbc9-575ee9498717",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_set = []\n",
    "for _ in range(10):\n",
    "    line = json.loads(next(lines))\n",
    "    answer = re.findall(r\"\\d+\", line[\"answer\"])[-1]\n",
    "    example_set.append({\"question\": line[\"question\"], \"answer\": answer})\n",
    "\n",
    "train_set = []\n",
    "for _ in range(500):\n",
    "    line = json.loads(next(lines))\n",
    "    answer = re.findall(r\"\\d+\", line[\"answer\"])[-1]\n",
    "    train_set.append({\"question\": line[\"question\"], \"answer\": answer})"
   ]
  },
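  {
   "cell_type": "markdown",
   "id": "sketch-answer-extraction",
   "metadata": {},
   "source": [
    "The cell above keeps the last run of digits in each reference solution. GSM8K solutions end with a line of the form `#### <final answer>`, so a slightly more defensive variant can read that marker directly and strip thousands separators (with the plain digit regex, a final answer written as `1,250` would be read as `250`). The sketch below is an illustration only: `extract_final_answer` is a hypothetical helper, not part of the dataset tooling, and the solution string is made up for the example.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def extract_final_answer(solution):\n",
    "    # Hypothetical helper: return the number after the trailing '####' marker,\n",
    "    # or None when the marker is missing.\n",
    "    match = re.search(r'####\\s*(-?[\\d,]+)\\s*$', solution.strip())\n",
    "    if match is None:\n",
    "        return None\n",
    "    # Drop thousands separators so '1,250' compares equal to '1250'.\n",
    "    return match.group(1).replace(',', '')\n",
    "\n",
    "\n",
    "# Made-up solution text in the GSM8K answer format.\n",
    "solution = 'She sells 1,000 + 250 = 1,250 clips.\\n#### 1,250'\n",
    "print(extract_final_answer(solution))\n",
    "```\n",
    "\n",
    "If this variant were used for the reference answers, the model answers would need the same normalization before the two are compared."
   ]
  },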
  {
   "cell_type": "markdown",
   "id": "4b52b470-d818-495a-a6e3-e50a1deff13c",
   "metadata": {},
   "source": [
    "Now let's define the prompt, the model, and the sampling loop. For each problem, the sampling loop draws 5 examples at random and samples 20 model answers. Every time an answer is correct we keep the ids of the 5 examples as samples; otherwise we discard them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9fbebaa9-f05e-4c6b-8875-73a08273bbb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "few_shots = outlines.Template.from_string(\n",
    "    \"\"\"\n",
    "    {% for example in examples %}\n",
    "    Q: {{ example.question }}\n",
    "    A: {{ example.answer }}\n",
    "    {% endfor %}\n",
    "    Q: {{ question }}\n",
    "    A:\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "model = outlines.from_openai(openai.OpenAI(), \"gpt-4o-mini\")\n",
    "\n",
    "\n",
    "# TODO: This could largely benefit from vectorization in #52\n",
    "def one_train_example(problem, example_set):\n",
    "    # Draw 5 example ids uniformly at random. `random.choices` samples with\n",
    "    # replacement, so the same example can appear more than once in a prompt.\n",
    "    example_ids = random.choices(range(0, len(example_set)), k=5)\n",
    "    examples = [example_set[i] for i in example_ids]\n",
    "    prompt = few_shots(question=problem[\"question\"], examples=examples)\n",
    "    answers_raw = model(prompt, samples=20)\n",
    "\n",
    "    samples = []\n",
    "    for answer_raw in answers_raw:\n",
    "        # The last number in the completion is taken as the model's answer.\n",
    "        numbers = re.findall(r\"\\d+\", answer_raw)\n",
    "        # Keep the example ids once per correct answer; skip answers without a number.\n",
    "        if numbers and numbers[-1] == problem[\"answer\"]:\n",
    "            samples += example_ids\n",
    "\n",
    "    return samples"
   ]
  },
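  {
   "cell_type": "markdown",
   "id": "sketch-mock-sampling-loop",
   "metadata": {},
   "source": [
    "To see why accumulating the ids of successful draws is informative, the loop can be dry-run against a stand-in for the model. The sketch below makes no API call and is not the method used in this notebook: `helpfulness` and `success_probability` are hypothetical, with an invented weight per example id, and a draw is accepted with probability equal to the mean weight of the 5 chosen examples. Under that assumption, ids with higher weights should be accepted more often.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "rng = random.Random(0)\n",
    "\n",
    "# Hypothetical helpfulness per example id (invented for this sketch).\n",
    "helpfulness = [0.9, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]\n",
    "\n",
    "\n",
    "def success_probability(ids):\n",
    "    # Toy simulator: chance that a prompt built from these ids is answered correctly.\n",
    "    return sum(helpfulness[i] for i in ids) / len(ids)\n",
    "\n",
    "\n",
    "accepted = Counter()\n",
    "for _ in range(5000):\n",
    "    # Same proposal as in the notebook: 5 ids drawn uniformly, with replacement.\n",
    "    ids = rng.choices(range(len(helpfulness)), k=5)\n",
    "    # Keep the ids only when the simulated answer is correct.\n",
    "    if rng.random() < success_probability(ids):\n",
    "        accepted.update(ids)\n",
    "\n",
    "print(accepted.most_common(3))\n",
    "```\n",
    "\n",
    "Because every id is proposed equally often, differences in the accepted counts reflect differences in how often prompts containing that id succeed."
   ]
  },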
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1dae1ef2-c9e0-4c98-8686-7fbc2ff55e56",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9efc9d077af24a2eb5ea3c05fe63f298",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.contrib import tenumerate\n",
    "\n",
    "samples = []\n",
    "for i, train_example in tenumerate(train_set):\n",
    "    samples += one_train_example(train_example, example_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d084a3c1-ec84-48ab-91d7-b4c64afce512",
   "metadata": {},
   "source": [
    "We can now plot the number of times each example was sampled:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c211e0af-33c6-4e35-af83-194dcef31f01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'Counts')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAtoAAAHgCAYAAACb58plAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbiElEQVR4nO3da7BldXnn8d8jLV7wAmqHIsBMk5FxtIyJpEW8jhFjQBxxLEQdo5SFYVKjxktubZIqy5kXg1OWGmdSZFBUrFHUoJZEGJRBvE2NaItEFHTsIEozKK0i3pIo+syLszDHtpHT0P+9zuXzqTp19l5r7X2ebRfdX9f5772quwMAAOxbd5p7AAAAWI+ENgAADCC0AQBgAKENAAADCG0AABhAaAMAwACb5h5ghOOOO64vvPDCuccAAGD9q1vbsS7PaH/jG9+YewQAADa4dRnaAAAwt2GhXVVvqqobqupzy7bdp6ouqqovTd8PmrZXVb2+qnZU1Wer6qhljzllOv5LVXXKqHkBAGBfGnlG+y1Jjttt27YkF3f3kUkunu4nyfFJjpy+TktyRrIU5klekeThSY5O8opb4hwAAFazYaHd3R9N8q3dNp+Y5Ozp9tlJnrps+1t7ySeSHFhVhyT57SQXdfe3uvvGJBfl5+MdAABWnUWv0T64u6+fbn8tycHT7UOTXLvsuJ3Ttlvb/nOq6rSq2l5V23ft2rVvpwYAgL0025shu7uT9D58vjO7e2t3b928efO+eloAALhdFh3aX5+WhGT6fsO0/bokhy877rBp261tBwCAVW3RoX1ekls+OeSUJO9btv2506ePHJPkpmmJyQeSPLGqDpreBPnEaRsAAKxqw64MWVXnJHlckvtV1c4sfXrI6UneVVWnJvlKkpOnwy9I8qQkO5L8IMnzkqS7v1VV/ynJp6bj/mN37/4GSwAAWHVqaan0+rJ169bevn373GMAALD+baxLsAMAwNyENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhl2CHQCA1W/LtvPnHuEOu+b0E+YeYY+c0QYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAm+YeANayLdvOn3uEfeKa00+YewQAWHec0QYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGCATXMPwPqwZdv5c4+wT1xz+glzjwAArBPOaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAAf7wcAe+BjS4E7apYz2lX10qr6fFV9rqrOqaq7VtURVXVpVe2oqndW1f7TsXeZ7u+Y9m+ZY2YAANgbCw/tqjo0ye8n2drdD06yX5JnJnlVktd29/2T3Jjk1Okhpya5cdr+2uk4AABY1eZao70pyd2qalOSuye5Psnjk5w77T87yVOn2ydO9zPtP7aqanGjAgDA3lt4aHf3dUleneSrWQrsm5J8Osm3u/vm6bCdSQ6dbh+a5NrpsTdPx993kTMDAMDemmPpyEFZOkt9RJJfTnJAkuP2wfOeVlXbq2r7rl277ujTAQDAHTLH0pEnJPlyd+/q7h8leU+SRyU5cFpKkiSHJbluun1dksOTZNp/7yTf3P1Ju/vM7t7a3Vs3b948+jUAAMAvNEdofzXJ
MVV192mt9bFJrkxySZKTpmNOSfK+6fZ50/1M+z/U3b3AeQEAYK/NsUb70iy9qfGyJFdMM5yZ5E+SvKyqdmRpDfZZ00POSnLfafvLkmxb9MwAALC3ZrlgTXe/Iskrdtt8dZKj93DsPyR5+iLmAgCAfcUl2AEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGGDT3AOsN1u2nT/3CHfYNaefMPcIAABrnjPaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADLBp7gEAWN22bDt/7hH2iWtOP2HuEdaM9fBn7s+b1cAZbQAAGEBoAwDAAEIbAAAGsEYb2GvrYf1mYg0nAGM5ow0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYYNPcAwAArAZbtp0/9wh32DWnnzD3CCzjjDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABpgltKvqwKo6t6q+UFVXVdUjquo+VXVRVX1p+n7QdGxV1eurakdVfbaqjppjZgAA2BtzndH+iyQXdve/SvJrSa5Ksi3Jxd19ZJKLp/tJcnySI6ev05KcsfhxAQBg7yw8tKvq3kkem+SsJOnuH3b3t5OcmOTs6bCzkzx1un1ikrf2kk8kObCqDlno0AAAsJfmOKN9RJJdSd5cVZ+pqjdW1QFJDu7u66djvpbk4On2oUmuXfb4ndM2AABYteYI7U1JjkpyRnc/NMn380/LRJIk3d1Jem+etKpOq6rtVbV9165d+2xYAAC4PeYI7Z1Jdnb3pdP9c7MU3l+/ZUnI9P2Gaf91SQ5f9vjDpm0/o7vP7O6t3b118+bNw4YHAICVWHhod/fXklxbVQ+YNh2b5Mok5yU5Zdp2SpL3TbfPS/Lc6dNHjkly07IlJgAAsCptmunnvijJ26pq/yRXJ3lelqL/XVV1apKvJDl5OvaCJE9KsiPJD6ZjAQBgVZsltLv78iRb97Dr2D0c20leMHomAADYl1wZEgAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAH2OrSr6qCqesiIYQAAYL1YUWhX1Yer6l5VdZ8klyV5Q1W9ZuxoAACwdq30jPa9u/s7SZ6W5K3d/fAkTxg3FgAArG0rDe1NVXVIkpOTvH/gPAAAsC6sNLRfmeQDSXZ096eq6leSfGncWAAAsLZtWuFx13f3T98A2d1XW6MNAAC3bqVntP/rCrcBAAC5jTPaVfWIJI9MsrmqXrZs172S7DdyMAAAWMtua+nI/knuMR13z2Xbv5PkpFFDAQDAWvcLQ7u7P5LkI1X1lu7+yoJmAgCANW+lb4a8S1WdmWTL8sd09+NHDAUAAGvdSkP7r5P8VZI3JvnxuHEAAGB9WGlo39zdZwydBAAA1pGVfrzf31TVf6iqQ6rqPrd8DZ0MAADWsJWe0T5l+v5Hy7Z1kl/Zt+MAAMD6sKLQ7u4jRg8CAADryYpCu6qeu6ft3f3WfTsOAACsDytdOvKwZbfvmuTYJJclEdoAALAHK1068qLl96vqwCTvGDEQAACsByv91JHdfT+JddsAAHArVrpG+2+y9Ckj
SbJfkgcmedeooQAAYK1b6RrtVy+7fXOSr3T3zgHzAADAurCipSPd/ZEkX0hyzyQHJfnhyKEAAGCtW1FoV9XJST6Z5OlJTk5yaVWdNHIwAABYy1a6dOTPkjysu29IkqranOR/JTl31GAAALCWrfRTR+50S2RPvrkXjwUAgA1npWe0L6yqDyQ5Z7r/jCQXjBkJAADWvl8Y2lV1/yQHd/cfVdXTkjx62vV/krxt9HAAALBW3dYZ7dcleXmSdPd7krwnSarqV6d9/2bgbACrypZt5889wj5xzeknzD0CwIZwW+usD+7uK3bfOG3bMmQiAABYB24rtA/8Bfvutg/nAACAdeW2Qnt7Vf3u7hur6vlJPj1mJAAAWPtua432S5K8t6qenX8K661J9k/ybwfOBQAAa9ovDO3u/nqSR1bVbyZ58LT5/O7+0PDJAABgDVvR52h39yVJLhk8CwAArBuu7ggAAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMMFtoV9V+VfWZqnr/dP+Iqrq0qnZU1Turav9p+12m+zum/VvmmhkAAFZqzjPaL05y1bL7r0ry2u6+f5Ibk5w6bT81yY3T9tdOxwEAwKo2S2hX1WFJTkjyxul+JXl8knOnQ85O8tTp9onT/Uz7j52OBwCAVWuuM9qvS/LHSX4y3b9vkm93983T/Z1JDp1uH5rk2iSZ9t80Hf8zquq0qtpeVdt37do1cHQAALhtCw/tqnpykhu6+9P78nm7+8zu3trdWzdv3rwvnxoAAPbaphl+5qOSPKWqnpTkrknuleQvkhxYVZums9aHJbluOv66JIcn2VlVm5LcO8k3Fz82AACs3MLPaHf3y7v7sO7ekuSZST7U3c9OckmSk6bDTknyvun2edP9TPs/1N29wJEBAGCvrabP0f6TJC+rqh1ZWoN91rT9rCT3nba/LMm2meYDAIAVm2PpyE9194eTfHi6fXWSo/dwzD8kefpCBwMAgDtoNZ3RBgCAdUNoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADDAwkO7qg6vqkuq6sqq+nxVvXjafp+quqiqvjR9P2jaXlX1+qraUVWfraqjFj0zAADsrTnOaN+c5A+6+0FJjknygqp6UJJtSS7u7iOTXDzdT5Ljkxw5fZ2W5IzFjwwAAHtn4aHd3dd392XT7e8muSrJoUlOTHL2dNjZSZ463T4xyVt7ySeSHFhVhyx2agAA2DuzrtGuqi1JHprk0iQHd/f1066vJTl4un1okmuXPWzntA0AAFat2UK7qu6R5N1JXtLd31m+r7s7Se/l851WVduravuuXbv24aQAALD3ZgntqrpzliL7bd39nmnz129ZEjJ9v2Hafl2Sw5c9/LBp
28/o7jO7e2t3b928efO44QEAYAXm+NSRSnJWkqu6+zXLdp2X5JTp9ilJ3rds+3OnTx85JslNy5aYAADAqrRphp/5qCTPSXJFVV0+bfvTJKcneVdVnZrkK0lOnvZdkORJSXYk+UGS5y10WgAAuB0WHtrd/fEkdSu7j93D8Z3kBUOHAgCAfcyVIQEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAYQ2gAAMIDQBgCAAYQ2AAAMILQBAGAAoQ0AAAMIbQAAGEBoAwDAAEIbAAAGENoAADCA0AYAgAGENgAADCC0AQBgAKENAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtAAAYQGgDAMAAQhsAAAZYM6FdVcdV1RerakdVbZt7HgAA+EXWRGhX1X5J/jLJ8UkelORZVfWgeacCAIBbtyZCO8nRSXZ099Xd/cMk70hy4swzAQDArVoroX1okmuX3d85bQMAgFWpunvuGW5TVZ2U5Ljufv50/zlJHt7dL1x2zGlJTpvuPiDJFxc+6OLcL8k35h5iBl73xuJ1byxe98ayUV93snFf+3p+3d/o7uP2tGPToie5na5Lcviy+4dN236qu89McuYih5pLVW3v7q1zz7FoXvfG4nVvLF73xrJRX3eycV/7Rn3da2XpyKeSHFlVR1TV/kmemeS8mWcCAIBbtSbOaHf3zVX1wiQfSLJfkjd19+dnHgsAAG7VmgjtJOnuC5JcMPccq8SGWCKzB173xuJ1byxe98ayUV93snFf+4Z83WvizZAAALDWrJU12gAAsKYI7TVko16GvqreVFU3VNXn5p5lUarq8Kq6pKqurKrPV9WL555pUarqrlX1yar62+m1v3LumRalqvarqs9U1fvnnmWRquqaqrqiqi6vqu1zz7MoVXVgVZ1bVV+oqquq6hFzzzRaVT1g+nO+5es7VfWSuedahKp66fR32ueq6pyquuvcMy1CVb14es2f3yh/1stZOrJGTJeh/79JfitLF+z5VJJndfeVsw62AFX12CTfS/LW7n7w3PMsQlUdkuSQ7r6squ6Z5NNJnrpB/rwryQHd/b2qunOSjyd5cXd/YubRhquqlyXZmuRe3f3kuedZlKq6JsnW7l6vn7G7R1V1dpKPdfcbp0/Uunt3f3vmsRZm+nftuixdF+Mrc88zUlUdmqW/yx7U3X9fVe9KckF3v2Xeycaqqgdn6WreRyf5YZILk/xed++YdbAFckZ77diwl6Hv7o8m+dbccyxSd1/f3ZdNt7+b5KpskKuh9pLvTXfvPH2t+zMCVXVYkhOSvHHuWRivqu6d5LFJzkqS7v7hRorsybFJ/m69R/Yym5Lcrao2Jbl7kv838zyL8MAkl3b3D7r75iQfSfK0mWdaKKG9drgM/QZVVVuSPDTJpTOPsjDTEorLk9yQ5KLu3giv/XVJ/jjJT2aeYw6d5INV9enpKr8bwRFJdiV587Rc6I1VdcDcQy3YM5OcM/cQi9Dd1yV5dZKvJrk+yU3d/cF5p1qIzyV5TFXdt6runuRJ+dkLEK57QhtWsaq6R5J3J3lJd39n7nkWpbt/3N2/nqWrwB49/fpx3aqqJye5obs/PfcsM3l0dx+V5PgkL5iWi613m5IcleSM7n5o
ku8n2Ujvvdk/yVOS/PXcsyxCVR2Upd9CH5Hkl5McUFW/M+9U43X3VUleleSDWVo2cnmSH88506IJ7bXjNi9Dz/oyrU9+d5K3dfd75p5nDtOv0i9JctzMo4z2qCRPmdYqvyPJ46vqf8w70uJMZ/vS3TckeW+WlsqtdzuT7Fz225pzsxTeG8XxSS7r7q/PPciCPCHJl7t7V3f/KMl7kjxy5pkWorvP6u7f6O7HJrkxS+832zCE9trhMvQbyPSGwLOSXNXdr5l7nkWqqs1VdeB0+25ZegPwF2YdarDufnl3H9bdW7L03/aHunvdn+1Kkqo6YHrDb6alE0/M0q+b17Xu/lqSa6vqAdOmY5Os+zc7L/OsbJBlI5OvJjmmqu4+/f1+bJbee7PuVdUvTd//WZbWZ7993okWa81cGXKj28iXoa+qc5I8Lsn9qmpnkld091nzTjXco5I8J8kV01rlJPnT6Qqp690hSc6ePpHgTkne1d0b6uPuNpiDk7x3qT2yKcnbu/vCeUdamBcledt08uTqJM+beZ6FmP4P1W8l+fdzz7Io3X1pVZ2b5LIkNyf5TDbOlRLfXVX3TfKjJC/YaG/69fF+AAAwgKUjAAAwgNAGAIABhDYAAAwgtAEAYAChDQAAAwhtgFWuqn5cVZcv+5rlCoJVdU1V3e92PO63q+qVVXWfqvqfI2YDWI18jjbA6vf30yXp16rHZOkKn49J8vGZZwFYGGe0Adagqrp3VX3xlisLVtU5VfW70+0zqmp7VX2+ql657DHXVNV/ns6Kb6+qo6rqA1X1d1X1e9Mxj6uqj1bV+dPz/1VV/dy/FVX1O1X1yem5/vt0gaHdj3nGdMGl30/yuiRvSPK8qnJVW2BDENoAq9/ddls68ozuvinJC5O8paqemeSg7n7DdPyfdffWJA9J8q+r6iHLnuur09nxjyV5S5KTkhyT5JXLjjk6S1ctfFCSf5Glyyb/VFU9MMkzkjxqeq4fJ3n27kN39zuTPDTJ57r7V5NckeSh3f2U2/8/BcDaYekIwOq3x6Uj3X1RVT09yV8m+bVlu06uqtOy9Hf8IVkK5s9O+245m3xFknt093eTfLeq/rGqDpz2fbK7r06WzpQneXSSc5c9/7FJfiPJp6ZLp98tyQ23Mvu/zNLlxZPkgOnnAWwIQhtgjZqWdDwwyQ+SHJRkZ1UdkeQPkzysu2+sqrckueuyh/3j9P0ny27fcv+WfxN6tx+1+/1KcnZ3v/w25tue5H5JNlXVlUkOmZaSvKi7P3bbrxBgbbN0BGDtemmSq5L8uyRvrqo7J7lXku8nuamqDk5y/O143qOr6ogp5J+Rn38D48VJTqqqX0qS6dNE/vnuTzItXzk/yYlJ/kuWlrT8usgGNgqhDbD67b5G+/TpTZDPT/IHU7h+NMmfd/ffJvlMki8keXuS/307ft6nkvy3LEX8l5O8d/nO7r4yyZ8n+WBVfTbJRVlaorInRyW5PEufOPKR2zELwJpV3bv/RhCAjaqqHpfkD7v7yTOPArDmOaMNAAADOKMNAAADOKMNAAADCG0AABhAaAMAwABCGwAABhDaAAAwgNAGAIAB/j/+cpD1bixMkQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 864x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
    ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "example_ids, counts = np.unique(samples, return_counts=True)\n",
    "\n",
    "fig = plt.figure(figsize=(12, 8))\n",
    "ax = fig.add_subplot(111)\n",
    "ax.bar(example_ids, counts)\n",
    "\n",
    "ax.spines[[\"top\", \"right\"]].set_visible(False)\n",
    "\n",
    "ax.set_xticks(range(10))\n",
    "ax.set_xlabel(\"Example #\")\n",
    "ax.set_ylabel(\"Counts\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde37e5b-377e-4872-af40-674d680bd2da",
   "metadata": {},
   "source": [
    "Looking at the distribution, our best guess for which examples we should use for benchmarking on the test set would be 0, 1, 2, 6 and 9. This method can be trivially extended to other workflows that use few-shot examples to query LLMs. Of course, simulation-based inference extends beyond choosing the \"best\" prompt and could, for instance, also be used to select the structure of chains of LLMs and tools."
   ]
  },
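  {
   "cell_type": "markdown",
   "id": "sketch-top-k-examples",
   "metadata": {},
   "source": [
    "Reading the five tallest bars off the plot can also be done programmatically. The sketch below is a minimal illustration that uses a small made-up `toy_samples` list in place of the real `samples`; `top_k_examples` is a hypothetical helper, and ties are resolved in order of first appearance, which is simply how `Counter.most_common` behaves.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def top_k_examples(sample_ids, k=5):\n",
    "    # Return the k ids that were accepted most often, most frequent first.\n",
    "    return [example_id for example_id, _ in Counter(sample_ids).most_common(k)]\n",
    "\n",
    "\n",
    "# Made-up accepted ids, standing in for the `samples` list built earlier.\n",
    "toy_samples = [0, 0, 0, 1, 1, 2, 2, 2, 2, 6, 9, 9, 3]\n",
    "print(top_k_examples(toy_samples, k=3))\n",
    "```\n",
    "\n",
    "When two ids have nearly equal counts, the ranking between them is sensitive to sampling noise, so the cut-off at 5 is a heuristic rather than a firm boundary."
   ]
  },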
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bddda20b-234a-4d30-b40a-90708fbaba23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',\n",
       " 'answer': '72'}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_set[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fb186bf9-62b7-485f-a8ce-401f551a9e57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?',\n",
       " 'answer': '10'}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_set[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ae427bb2-e3f4-4a96-a508-e8011a0fc553",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?',\n",
       " 'answer': '5'}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_set[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fe43ae0f-c18f-4b74-b639-8481472edf4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?',\n",
       " 'answer': '48'}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_set[6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "19d9d936-d0f0-4927-990c-76dbbfa95b47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Tina makes $18.00 an hour.  If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage.  If she works 10 hours every day for 5 days, how much money does she make?',\n",
       " 'answer': '990'}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_set[9]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
